# -*- coding: utf-8 -*-
"""
Created on Fri Feb 27 16:52:51 2026

@author: Wieczorek_W_Station
"""

import os
import re
import numpy as np
import pandas as pd
from transformers import BertTokenizer, BertModel
from transformers import AutoModelForSequenceClassification, AutoTokenizer
import random as rd
import torch
from tqdm import tqdm
import pickle
# NOTE(review): presumably meant to switch off TensorFlow's oneDNN
# optimizations. A plain Python variable is never read by TensorFlow, so the
# value is also exported as an environment variable; it can only take effect
# if TensorFlow has not been imported yet at this point.
TF_ENABLE_ONEDNN_OPTS = 0
os.environ["TF_ENABLE_ONEDNN_OPTS"] = "0"

# input CSVs are read from <root>/Molts; pickles are written to <root>/PreparedData
root = "C:\\Users\\Wieczorek_W_Station\\Dropbox\\Arbeit Kassel\\paperideen\\Moltbook_Science\\Data\\"
path = os.path.join(root, "Molts")
output = os.path.join(root, "PreparedData")
# create the output folder if needed; an existing folder is fine, while real
# OS errors (permissions, invalid path) are raised instead of being swallowed
os.makedirs(output, exist_ok=True)

#%%
# =============================================================================
# define helper functions
# =============================================================================

def split_sentences(text):
    """Split *text* into sentences at whitespace that follows '.', '!' or '?'.

    Surrounding whitespace is removed from the text and from every sentence;
    pieces that end up empty are dropped.
    """
    pieces = re.split(r'(?<=[.!?])\s+', text.strip())
    trimmed = (piece.strip() for piece in pieces)
    return [piece for piece in trimmed if piece]

def split_long_sentence_balanced(sentence, tokenizer, max_tokens=512):
    """Split *sentence* into roughly equal pieces of at most *max_tokens* tokens.

    The sentence is encoded without special tokens. If it already fits, it is
    returned unchanged as a one-element list. Otherwise the token sequence is
    cut into the smallest number of pieces that can each hold at most
    ``max_tokens`` tokens, with cut points spread evenly, and every piece is
    decoded back to text. Piece sizes refer to the original encoding;
    re-encoding a decoded piece may not give exactly the same count.
    """
    token_ids = tokenizer.encode(sentence, add_special_tokens=False)
    n_tokens = len(token_ids)

    if n_tokens <= max_tokens:
        return [sentence]

    n_pieces = int(np.ceil(n_tokens / max_tokens))
    step = n_tokens / n_pieces
    # evenly spaced interior cut points, framed by the sequence start and end
    cuts = [0] + [int(round(step * k)) for k in range(1, n_pieces)] + [n_tokens]

    return [tokenizer.decode(token_ids[lo:hi]) for lo, hi in zip(cuts[:-1], cuts[1:])]

def flatten_list(x):
    """Return the items of the (possibly nested) list *x* as one flat list.

    Only ``list`` instances are expanded; tuples, strings and any other
    objects are kept as single items. Depth-first order is preserved.
    """
    def _walk(seq):
        # yield leaves in order, descending into nested lists
        for element in seq:
            if isinstance(element, list):
                yield from _walk(element)
            else:
                yield element

    return list(_walk(x))

def balanced_chunks(sentences, cum_tokens, max_tokens=512):
    """Group consecutive sentences into chunks of roughly equal token count.

    Parameters
    ----------
    sentences : list of str
        Sentences in reading order.
    cum_tokens : array-like of int
        Cumulative token counts: ``cum_tokens[i]`` is the number of tokens in
        ``sentences[:i + 1]``.
    max_tokens : int, default 512
        Upper bound, in the units of ``cum_tokens``, for the size of a chunk.
        Special tokens a tokenizer adds later are not part of this count, so
        callers feeding the chunks to a model should reserve room for them.

    Returns
    -------
    list of str
        Chunks in order, each the space-joined sentences it contains. Every
        chunk holds at least one sentence; an empty input gives ``[]``.

    Notes
    -----
    The text is first cut at evenly spaced token boundaries, using the
    smallest number of chunks that could fit. Cuts can only fall between
    sentences, so a chunk may still be too large. In that case the number of
    chunks is increased. If no balanced split fits, sentences are packed
    greedily. Only a single sentence that is itself longer than
    ``max_tokens`` can exceed the bound.
    """
    n_sentences = len(sentences)
    if n_sentences == 0:
        return []

    def _join(cuts):
        # one chunk per pair of consecutive cut positions
        return [" ".join(sentences[a:b]) for a, b in zip(cuts[:-1], cuts[1:])]

    cum = np.asarray(cum_tokens)
    # offsets[i] = tokens before sentence i; offsets[-1] = total tokens
    offsets = np.concatenate(([0], cum))
    total = offsets[-1]

    n_chunks = max(1, int(np.ceil(total / max_tokens)))
    while n_chunks <= n_sentences:
        boundaries = (total / n_chunks) * np.arange(1, n_chunks)
        # cut after the last sentence that ends at or before each boundary
        # (side="left" would push a sentence ending exactly on the boundary
        # into the next chunk)
        split_idx = np.searchsorted(cum, boundaries, side="right")
        # several boundaries inside one sentence give duplicate cuts, which
        # would otherwise produce empty chunks
        cuts = np.unique(np.concatenate(([0], split_idx, [n_sentences])))
        if np.diff(offsets[cuts]).max() <= max_tokens:
            return _join(cuts)
        n_chunks += 1

    # fallback: greedy packing never exceeds max_tokens unless a single
    # sentence is already larger than that
    cuts = [0]
    used = 0
    for i, size in enumerate(np.diff(offsets)):
        if used and used + size > max_tokens:
            cuts.append(i)
            used = 0
        used += size
    cuts.append(n_sentences)
    return _join(cuts)


def predict_sentiment(text, tokenizer=None, model=None):
    """Classify the sentiment of *text* and return the predicted class id.

    Parameters
    ----------
    text : str
        Input text. It is lower-cased, then encoded with truncation to at most
        512 tokens.
    tokenizer : callable, optional
        Tokenizer to use; defaults to the module-level ``sentimentTokenizer``.
    model : torch.nn.Module, optional
        Sequence-classification model whose output has ``.logits``; defaults
        to the module-level ``sentimentModel``.

    Returns
    -------
    int
        Index of the highest logit. For the 5-class model used in this script,
        callers map it with ``{0: "Very Negative", ..., 4: "Very Positive"}``.
    """
    tok = sentimentTokenizer if tokenizer is None else tokenizer
    clf = sentimentModel if model is None else model

    inputs = tok(text.lower(), return_tensors="pt", truncation=True,
                 padding=True, max_length=512)
    # send the encoded inputs to wherever the model lives (CPU or GPU)
    device = next(clf.parameters()).device
    inputs = {key: value.to(device) for key, value in inputs.items()}

    with torch.no_grad():
        logits = clf(**inputs).logits

    # softmax is monotonic, so the argmax of the logits is the predicted class
    return int(torch.argmax(logits, dim=-1).item())

#%%
# =============================================================================
# load data
# =============================================================================
os.chdir(path)

# "Unnamed: 0" is the unnamed leading column pandas creates on read
# (presumably a saved index), so it is dropped
threadsDf = pd.read_csv("ThreadsAll01.csv", sep=";").drop(columns="Unnamed: 0")
commentsDf = pd.read_csv("CommentsAll01.csv", sep=";").drop(columns="Unnamed: 0")

# =============================================================================
# initialize tokenizer and model
# =============================================================================
random_seed = 42
rd.seed(random_seed)

# seed PyTorch on the CPU and, if present, on every CUDA device
torch.manual_seed(random_seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(random_seed)

# SciBERT tokenizer and encoder used for the chunk embeddings
tokenizer = BertTokenizer.from_pretrained('allenai/scibert_scivocab_uncased')
model = BertModel.from_pretrained('allenai/scibert_scivocab_uncased')

# sentiment classifier; the encoding options (lower-casing, truncation,
# padding, max_length) are call-time arguments and are passed on every call
# in predict_sentiment rather than to from_pretrained
sentimentTokenizer = AutoTokenizer.from_pretrained("tabularisai/robust-sentiment-analysis")
sentimentModel = AutoModelForSequenceClassification.from_pretrained(
    "tabularisai/robust-sentiment-analysis")
#%%
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# =============================================================================
# shared processing: chunk, embed and score one text
# =============================================================================
def _embed_and_score(text, max_size=512):
    """Split *text* into model-sized chunks, embed them and score their sentiment.

    Uses the module-level ``tokenizer``, ``model`` and ``device`` and the
    helper functions defined at the top of this script.

    Parameters
    ----------
    text : str
        Raw text. NaN/None yields three empty lists; other non-string values
        are converted with ``str``.
    max_size : int, default 512
        Maximum encoded length per chunk, including special tokens.

    Returns
    -------
    chunks : list of str
        The text chunks produced by ``balanced_chunks``.
    embeddings : list of torch.Tensor
        ``pooler_output`` of ``model`` for each chunk (one tensor per chunk,
        model called with a batch of one), moved to the CPU.
    sentiments : list of int
        ``predict_sentiment`` class id for each chunk.
    """
    if not isinstance(text, str):
        if pd.isna(text):
            return [], [], []
        text = str(text)

    ## split input text into sentences
    sentences = split_sentences(text)
    if not sentences:
        return [], [], []

    # [CLS]/[SEP] are added when a chunk is encoded. Content must leave room
    # for them, otherwise a full-length piece exceeds max_size.
    budget = max_size - tokenizer.num_special_tokens_to_add()

    ## split sentences that are longer than the budget, then flatten
    sentences = flatten_list(
        [split_long_sentence_balanced(s, tokenizer, budget) for s in sentences])

    ## content-token count per sentence (no special tokens, no padding)
    lengths = [len(ids) for ids in
               tokenizer(sentences, add_special_tokens=False)["input_ids"]]

    ## divide the text into chunks of roughly the same size
    chunks = balanced_chunks(sentences, np.cumsum(lengths), budget)

    ## tokenize the chunks (truncation is only a safety net here)
    encoded = tokenizer(
        chunks,
        max_length=max_size,
        truncation=True,
        return_tensors="pt",
        padding="max_length").to(device)

    ## create one embedding per chunk
    embeddings = []
    with torch.no_grad():
        for ids, mask in zip(encoded["input_ids"], encoded["attention_mask"]):
            outputs = model(input_ids=ids.unsqueeze(0),
                            attention_mask=mask.unsqueeze(0))
            # keep results on the CPU: GPU tensors would accumulate across the
            # whole corpus and the pickle would need CUDA to be loaded
            embeddings.append(outputs["pooler_output"].cpu())

    ## sentiment analysis per chunk (predict_sentiment lower-cases itself)
    sentiments = [predict_sentiment(chunk) for chunk in chunks]
    return chunks, embeddings, sentiments


# =============================================================================
# threads
# =============================================================================
embeddingsDict = []
threads = threadsDf.content
threadsIds = threadsDf.id

max_size = 512
sentiment_map = {0: "Very Negative", 1: "Negative", 2: "Neutral", 3: "Positive", 4: "Very Positive"}

for thread, id_ in tqdm(zip(threads, threadsIds), total=len(threads)):
    chunks, embeddings, sentiments = _embed_and_score(thread, max_size)

    # average the per-chunk class ids, round, and map back to a label;
    # threads without text get no label
    avgSentiment = (sentiment_map[int(np.mean(sentiments).round(0))]
                    if sentiments else None)

    embeddingsDict.append({"thread_id": id_,
                           "content": thread,
                           "split_content": chunks,
                           "embeddings": embeddings,
                           "sentiment": avgSentiment})

# =============================================================================
# Save embeddings
# =============================================================================
os.chdir(output)
with open("ThreadEmbeddingsDict.pickle", "wb") as fh:
    pickle.dump(embeddingsDict, fh)

os.chdir(path)
threadsDf.insert(20, "sentiments", [d["sentiment"] for d in embeddingsDict])
threadsDf.to_csv("ThreadsAllSentiments.csv", sep=";")
#%%
# =============================================================================
# comments
# =============================================================================
# NOTE(review): dropna() without `subset` drops a comment when ANY column is
# missing, not only the comment text. Confirm that this is intended.
commentsDf.dropna(inplace=True)
CommentEmbeddingsDict = []

comments = commentsDf.comment
commentIds = commentsDf.post_id

max_size = 512

for comment, id_ in tqdm(zip(comments, commentIds), total=len(comments)):
    chunks, embeddings, sentiments = _embed_and_score(comment, max_size)

    # NOTE(review): unlike threads, comments store the raw mean class id
    # (0 = "Very Negative" ... 4 = "Very Positive") rather than a mapped
    # label. Confirm this difference is intended.
    avgSentiment = np.mean(sentiments) if sentiments else np.nan

    # the comment's post_id is stored under "thread_id"
    CommentEmbeddingsDict.append({"thread_id": id_,
                                  "content": comment,
                                  "split_content": chunks,
                                  "embeddings": embeddings,
                                  "sentiment": avgSentiment})

# =============================================================================
# save and export the dataframes & embeddings
# =============================================================================
os.chdir(output)
with open("commentEmbeddingsDict.pickle", "wb") as fh:
    pickle.dump(CommentEmbeddingsDict, fh)

os.chdir(path)
commentsDf.insert(3, "sentiments", [d["sentiment"] for d in CommentEmbeddingsDict])
commentsDf.to_csv("CommentsAllSentiments.csv", sep=";")